{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.8.1+cu102'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributions as td\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "\n",
    "from collections import namedtuple, OrderedDict, defaultdict\n",
    "from tqdm.auto import tqdm\n",
    "from itertools import chain\n",
    "from tabulate import tabulate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data import load_mnist, Batcher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = load_mnist(\n",
    "    100, \n",
    "    save_to='../tmp', \n",
    "    height=28, \n",
    "    width=28\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_batcher(data_loader):\n",
    "    batcher = Batcher(\n",
    "        data_loader, \n",
    "        height=28, \n",
    "        width=28, \n",
    "        device=torch.device('cpu'), \n",
    "        binarize=True, \n",
    "        num_classes=10,\n",
    "        onehot=False\n",
    "    )\n",
    "    return batcher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(55000, 785)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training = np.concatenate([np.concatenate([x_obs.reshape(-1, 28*28).numpy(), c_obs.unsqueeze(1).numpy()], -1) for x_obs, c_obs in get_batcher(train_loader)])\n",
    "training.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5000, 785)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid = np.concatenate([np.concatenate([x_obs.reshape(-1, 28*28).numpy(), c_obs.unsqueeze(1).numpy()], -1) for x_obs, c_obs in get_batcher(valid_loader)])\n",
    "valid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 785)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = np.concatenate([np.concatenate([x_obs.reshape(-1, 28*28).numpy(), c_obs.unsqueeze(1).numpy()], -1) for x_obs, c_obs in get_batcher(test_loader)])\n",
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for alg in ['ball_tree', 'kd_tree']:\n",
    "    for k in [1, 5, 10]:\n",
    "\n",
    "        model = KNeighborsClassifier(n_neighbors=k, algorithm=alg, p=2, n_jobs=10)\n",
    "        model.fit(training[:,:-1], training[:,-1])\n",
    "\n",
    "        # evaluate the model and update the accuracies list\n",
    "        score = model.score(valid[:,:-1], valid[:,-1])\n",
    "        print(f\"alg={alg} k={k}, accuracy={score*100:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='kd_tree', n_jobs=10)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = KNeighborsClassifier(n_neighbors=5, algorithm='kd_tree', n_jobs=10)\n",
    "model.fit(training[:,:-1], training[:,-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         0.0       0.98      0.99      0.98       489\n",
      "         1.0       0.91      1.00      0.95       530\n",
      "         2.0       0.98      0.96      0.97       493\n",
      "         3.0       0.96      0.96      0.96       509\n",
      "         4.0       0.98      0.95      0.96       499\n",
      "         5.0       0.95      0.95      0.95       458\n",
      "         6.0       0.98      0.99      0.98       482\n",
      "         7.0       0.96      0.98      0.97       563\n",
      "         8.0       0.99      0.89      0.94       494\n",
      "         9.0       0.95      0.94      0.95       483\n",
      "\n",
      "    accuracy                           0.96      5000\n",
      "   macro avg       0.96      0.96      0.96      5000\n",
      "weighted avg       0.96      0.96      0.96      5000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "val_pred = model.predict(valid[:,:-1])\n",
    "print(classification_report(valid[:,-1], val_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "         0.0       0.96      0.99      0.98       980\n",
      "         1.0       0.89      1.00      0.94      1135\n",
      "         2.0       0.99      0.93      0.96      1032\n",
      "         3.0       0.95      0.96      0.95      1010\n",
      "         4.0       0.96      0.93      0.95       982\n",
      "         5.0       0.94      0.95      0.94       892\n",
      "         6.0       0.97      0.98      0.98       958\n",
      "         7.0       0.93      0.95      0.94      1028\n",
      "         8.0       0.99      0.87      0.92       974\n",
      "         9.0       0.94      0.93      0.93      1009\n",
      "\n",
      "    accuracy                           0.95     10000\n",
      "   macro avg       0.95      0.95      0.95     10000\n",
      "weighted avg       0.95      0.95      0.95     10000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions = model.predict(test[:,:-1])\n",
    "print(classification_report(test[:,-1],predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# save the model to disk\n",
    "pickle.dump(model, open('knnclassifier.pickle', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9628\n"
     ]
    }
   ],
   "source": [
    "# load the model from disk\n",
    "loaded_model = pickle.load(open('knnclassifier.pickle', 'rb'))\n",
    "result = loaded_model.score(valid[:,:-1], valid[:,-1])\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
